Skip to content

[PyTorch][torch.compile] Remove process group from quantizers#3104

Open
pggPL wants to merge 6 commits into
NVIDIA:mainfrom
pggPL:remove_process_group_from_quantizers
Open

[PyTorch][torch.compile] Remove process group from quantizers#3104
pggPL wants to merge 6 commits into
NVIDIA:mainfrom
pggPL:remove_process_group_from_quantizers

Conversation

@pggPL

@pggPL pggPL commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

Description

This makes adding torch.compile support much easier.
Move amax reduction process group handling out of quantizer state and pass it per quantization call instead. This avoids storing process groups inside quantizers while keeping deprecated stored-group fallback behavior for compatibility.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality not to work as expected)
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Pass amax_reduction_group through quantize/module call paths instead of storing it on quantizers.
  • Preserve deprecated constructor/state fallback for existing callers, excluding process groups from serialization.
  • Update FP8/NVFP4/MXFP8/blockwise tensor quantization paths and C++ bindings to resolve reduction groups per call.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@pggPL

pggPL commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

@greptile-apps

greptile-apps Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR refactors amax reduction process group handling out of persistent quantizer state and into per-call arguments, making the quantizer objects free of distributed state and enabling torch.compile graph capture without baking process groups into compiled artifacts.

  • Per-call group injection: set_quantizer_amax_reduction_group is called at every entry point (_linear_forward_impl, _linear_backward, _LayerNormLinear, _LayerNormMLP, BasicLinear) immediately before quantization, replacing the _customize_quantizers_nvfp4 / _customize_quantizers_float8_current_scaling side-channels that stored the group on the quantizer between calls.
  • FSDP2 tensor-level group: Float8Tensor and NVFP4Tensor now carry an amax_reduction_group class attribute (default None) that fsdp_pre_all_gather populates; _set_data / update_quantized create a throwaway quantizer copy using the group, so the base quantizer stays group-free after each update.
  • torch.compile improvements in base.py: get_ub_is_fp8 gets @torch.compiler.assume_constant_result to avoid repeated UB queries inside compiled graphs, destroy_ub calls torch.compiler.reset() to invalidate stale baked constants, and the CustomRecipeState early-return now checks recipe identity (is) rather than just type, preventing stale quantizer reuse when different custom recipe instances are used.

Confidence Score: 5/5

Safe to merge; the refactoring correctly moves process groups out of persistent quantizer state and all existing amax-reduction code paths are covered by explicit per-call setup.

The logic change is well-contained: each module entry point explicitly sets or clears the amax reduction group immediately before quantization, and the Quantizer.quantize() cleanup prevents the group from leaking into result tensor copies. The identified concerns (source-quantizer state not explicitly cleared, layernorm_linear backward relying on persistence for the input quantizer, tensor-level group never cleared after FSDP use) are design trade-offs that work correctly for all current usage patterns. No data flow, scale-inverse, or AllReduce operations appear to be silently dropped or duplicated.

transformer_engine/pytorch/quantized_tensor.py (source-quantizer post-quantize state), transformer_engine/pytorch/module/layernorm_linear.py (implicit input-quantizer group persistence in backward), and transformer_engine/pytorch/tensor/float8_tensor.py / nvfp4_tensor.py (tensor-level amax_reduction_group lifetime).

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/linear.py Replaces _customize_quantizers_nvfp4 and in-module group storage with per-call set_quantizer_amax_reduction_group in both _linear_forward_impl and _linear_backward; bakes dgrad/wgrad_use_split_accumulator into LinearBwdArgs instead of re-querying the global recipe at backward time (removes graph break).
transformer_engine/pytorch/quantized_tensor.py Adds post-quantize cleanup to strip the amax reduction group from the result tensor's embedded quantizer copy; the SOURCE quantizer's state is intentionally not cleared here but relies on callers invoking set_quantizer_amax_reduction_group on every entry point.
transformer_engine/pytorch/tensor/float8_tensor.py Moves the FSDP2 amax reduction group from the quantizer to a tensor-level attribute (amax_reduction_group); _set_data creates a throwaway quantizer copy when the group is set; the group is set once in fsdp_pre_all_gather and never cleared, persisting for the tensor's lifetime.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Adds tensor-level amax_reduction_group (mirroring Float8Tensor) and reads it inside update_quantized via a throwaway copy; group is set in fsdp_pre_all_gather and similarly never cleared after use.
transformer_engine/pytorch/module/base.py Adds get_ub_is_fp8 decorated with @torch.compiler.assume_constant_result to avoid repeated UB queries inside compiled graphs; destroy_ub now calls torch.compiler.reset() to invalidate stale baked constants; fixes CustomRecipeState early-return to check recipe identity, not just recipe type.
transformer_engine/pytorch/module/_common.py Adds set_quantizer_amax_reduction_group helper used at every call-site to transiently configure the group before quantization.
transformer_engine/pytorch/module/layernorm_linear.py Removes _customize_quantizers_nvfp4 and moves amax group setting per-call into the autograd Function; backward sets grad-output quantizer group explicitly but does not reset input_quantizer, relying on persistence from the forward pass.
transformer_engine/pytorch/module/layernorm_mlp.py Removes _customize_quantizers_nvfp4; moves FC1-input and FC2-grad-output amax group setting per-call into the autograd Function, symmetric with layernorm_linear.
transformer_engine/pytorch/ops/basic/basic_linear.py Removes static SP-group assignment from _customize_quantizers_*; explicitly calls set_quantizer_amax_reduction_group in both forward and backward, conditioned on the all-gather flag rather than querying parallel mode manually.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Caller
    participant Module as Linear / LNLinear / LNMlp
    participant Helper as set_quantizer_amax_reduction_group
    participant Q as Quantizer (source)
    participant QImpl as quantize_impl
    participant Result as QuantizedTensor._quantizer (copy)

    Caller->>Module: "forward(inp, tp_group=...)"
    Module->>Helper: set_quantizer_amax_reduction_group(input_q, tp_group if SP+col else None)
    Helper->>Q: "q.with_amax_reduction = True / False"
    Module->>Q: q.quantize(tensor)
    Q->>QImpl: quantize_impl(tensor) — new QuantizedTensor
    QImpl-->>Q: result (with _quantizer copy)
    Q->>Result: if copy.with_amax_reduction — clear it
    Note over Q: SOURCE q still has with_amax_reduction=True

    Caller->>Module: "backward(grad, tp_group=...)"
    Module->>Helper: set_quantizer_amax_reduction_group(grad_out_q, tp_group if SP+row else None)
    Helper->>Q: "q.with_amax_reduction = True / False"
    Note over Module: _linear_backward also resets input_q explicitly
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant Caller
    participant Module as Linear / LNLinear / LNMlp
    participant Helper as set_quantizer_amax_reduction_group
    participant Q as Quantizer (source)
    participant QImpl as quantize_impl
    participant Result as QuantizedTensor._quantizer (copy)

    Caller->>Module: "forward(inp, tp_group=...)"
    Module->>Helper: set_quantizer_amax_reduction_group(input_q, tp_group if SP+col else None)
    Helper->>Q: "q.with_amax_reduction = True / False"
    Module->>Q: q.quantize(tensor)
    Q->>QImpl: quantize_impl(tensor) — new QuantizedTensor
    QImpl-->>Q: result (with _quantizer copy)
    Q->>Result: if copy.with_amax_reduction — clear it
    Note over Q: SOURCE q still has with_amax_reduction=True

    Caller->>Module: "backward(grad, tp_group=...)"
    Module->>Helper: set_quantizer_amax_reduction_group(grad_out_q, tp_group if SP+row else None)
    Helper->>Q: "q.with_amax_reduction = True / False"
    Note over Module: _linear_backward also resets input_q explicitly
Loading

Reviews (8): Last reviewed commit: "Carry amax reduction group on the Quanti..." | Re-trigger Greptile

Comment on lines +326 to +329
"""Quantize tensor"""
return self.quantize(tensor)
if amax_reduction_group is None:
return self.quantize(tensor)
return self.quantize(tensor, amax_reduction_group=amax_reduction_group)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The None guard here is redundant: self.quantize(tensor) and self.quantize(tensor, amax_reduction_group=None) are identical because quantize defaults the argument to None. The branch just adds noise.

Suggested change
"""Quantize tensor"""
return self.quantize(tensor)
if amax_reduction_group is None:
return self.quantize(tensor)
return self.quantize(tensor, amax_reduction_group=amax_reduction_group)
"""Quantize tensor"""
return self.quantize(tensor, amax_reduction_group=amax_reduction_group)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +200 to +203
@property
def rht_matrix(self) -> torch.Tensor:
"""RHT matrix (fetched from the process-global cache, not stored per quantizer)."""
return get_rht_matrix(self._with_random_sign_mask, torch.cuda.current_device())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Deserialization break for old pickled NVFP4Quantizer instances

rht_matrix is now a property that reads self._with_random_sign_mask, but _with_random_sign_mask is a new field that did not exist in pickled state produced before this change. When Python's default __setstate__ (i.e., self.__dict__.update(state)) loads an old pickle, _with_random_sign_mask is absent, so any access to the rht_matrix property raises AttributeError. A __setstate__ that infers _with_random_sign_mask from the old stored rht_matrix (or supplies a safe default) would preserve backward compatibility for serialized quantizers.

@pggPL

pggPL commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator Author

Blocked by FSDP bug, refactor in progress.

I plan to store .amax_reduction_group in QuantizedTensor.

@timmoon10 timmoon10 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be a design mistake. The amax reduction does not have a consistent meaning across recipes (including recipes where it doesn't make sense), and this change requires spilling out amax reduction logic into quantizer callsites (even where it doesn't make sense).

Can you go into more detail exactly why torch.compile doesn't work when quantizers have process groups? If we just want the quantizer to hold simple Python objects, maybe we can make the quantizer hold an int for the communicator ID. I envision something like:

class Float8CurrentScalingQuantizer(Quantizer):

    _communicator_cache = {}

    @property
    def amax_reduction_group(self):
        if self._amax_reduction_group_id is None:
            return None
        return Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id]

    @property.setter
    def amax_reduction_group(self, comm):
        if comm is None:
            self._amax_reduction_group_id = None
        self._amax_reduction_group_id = id(comm)
        Float8CurrentScalingQuantizer._communicator_cache[self._amax_reduction_group_id] = comm

I'm not sure how this would interact with checkpointing though.

dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
amax_reduction_group: Optional[Any] = None, # pylint: disable=unused-argument

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I strongly oppose this API change. amax reduction is very recipe-specific. It has different meanings for different recipes (FP8 DS might reduce over the TP+DP group, FP8 CS might only reduce over the TP group) and it has no meaning for other recipes (MXFP8 and FP8 block scaling). Moving it into the generic API will leak recipe-specific information, defeating the point of a generic API.

@pggPL

pggPL commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch L1

pggPL and others added 4 commits June 15, 2026 16:40
…stants; fix SP memory leak; test suite hook-up

Wrap CommOverlapCore pybind11 methods that return compile-time constants
so torch.compile(fullgraph=True) can trace through them without graph
breaks:
- `is_fp8_ubuf()` → `ub_is_fp8()` / `get_ub_is_fp8()` in base.py;
  `_ub_is_fp8()` in gemm.py
- `with_cublasmp()` → `ub_is_cublasmp()` in base.py

All callers in linear.py, layernorm_linear.py, layernorm_mlp.py,
base.py, gemm.py, userbuffers_backward_linear.py and
userbuffers_forward_linear.py updated.

Fix quantized grad_output not being freed early for column-parallel SP
backward. Row-parallel SP already called clear_tensor_data(grad_output)
to release the gathered tensor; column-parallel SP quantizes grad_output
to Float8TensorStorage but never freed it before returning.  Under
torch.compile reduce-overhead this leaves 3 live pool tensors at
recording end and triggers "Detected 3 tensor(s) in the cudagraph pool
not tracked as outputs".  Extend the existing clear_tensor_data guard to
cover both parallel modes.

Fix custom-recipe quantizer state being re-initialised on every forward
call even when the recipe object has not changed. The existing early-exit
for CustomRecipeState was missing an identity check on the recipe object,
so any repeated call with the same recipe would bypass the early-return
and rebuild quantizers unnecessarily.  Add `if recipe_state.recipe is
recipe: return` to restore the intended caching behaviour.

Add test_torch_compile.py to L0_pytorch_unittest so the autocast and
existing compile tests run in CI.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…-accumulator booleans

LinearBwdArgs stored the entire FP8 recipe object so the backward could
extract fp8_gemm_dgrad.use_split_accumulator and
fp8_gemm_wgrad.use_split_accumulator at GEMM time.  Recipe objects hold
process-group references and are not serialisable as compile-time
constants, making them incompatible with torch.compile custom-op paths.

Replace fp8_recipe with two plain bool fields:
- dgrad_use_split_accumulator (default _2X_ACC_DGRAD)
- wgrad_use_split_accumulator (default _2X_ACC_WGRAD)

These are resolved once in _linear_setup_ctx and passed into the args
struct, so the backward consumes scalars instead of a live recipe object.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…t_result

get_ub_is_fp8 bakes is_fp8_ubuf() as a compile-time constant; without a
reset, destroy_ub + re-init with different FP8 settings would read stale
values until recompile. Only affects in-memory caches, not disk.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the remove_process_group_from_quantizers branch from e9097d6 to 948cd6d Compare June 16, 2026 12:23
pggPL and others added 2 commits June 16, 2026 16:32
ToyLinear now overrides get_quantizer_roles so CustomRecipeState doesn't hit
the no-roles warning, which graph-breaks under fullgraph=True. qfactory
dispatches on role.tensor_type instead of a pre-baked string key.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
The amax reduction process group is no longer stored persistently on a module
quantizer or on a tensor's quantizer. No C++ changes.

- TP sequence parallel: the group is set on the input/grad-output quantizer at
  point of use in the fwd/bwd impls (linear, layernorm_linear, layernorm_mlp,
  ops basic_linear), replacing the setup-time _customize_quantizers wiring.
- FSDP2: the group is stored on Float8Tensor/NVFP4Tensor (set in
  fsdp_pre_all_gather) and applied to a throwaway quantizer copy during the
  in-place re-quant (update_quantized / _set_data).
- quantize() strips the group off the output tensor's quantizer so it never
  persists on any tensor's quantizer (breaks flatten/pickle otherwise).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL pggPL force-pushed the remove_process_group_from_quantizers branch from b8c1bec to 6c9b986 Compare June 16, 2026 14:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants